{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8347485f",
   "metadata": {},
   "source": [
    "# Build a Video Action Classification System in 5 Minutes\n",
    "\n",
    "This notebook shows how to build a video classification system from scratch using [Towhee](https://towhee.io/). A video classification system assigns videos to predefined categories. This tutorial uses the human-activity labels of a pretrained model as an example.\n",
    "\n",
    "Using sample data covering different classes of human activities, we will build a basic video classification system in 5 lines of code and evaluate its performance using Towhee. The tutorial also suggests some optimization options. At the end, we use [Gradio](https://gradio.app/) to create an interactive showcase."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "569571ec",
   "metadata": {},
   "source": [
    "## Preparation\n",
    "\n",
    "### Install packages\n",
    "\n",
    "Make sure you have installed the required Python packages:\n",
    "\n",
    "| package |\n",
    "| -- |\n",
    "| towhee |\n",
    "| towhee.models |\n",
    "| pillow |\n",
    "| ipython |\n",
    "| gradio |"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "d2d8e3e7",
   "metadata": {},
   "outputs": [
   ],
   "source": [
    "! python -m pip install -q towhee towhee.models pillow ipython gradio"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11ef6b1a",
   "metadata": {},
   "source": [
    "### Prepare data\n",
    "\n",
    "This tutorial uses a small dataset extracted from the validation set of [Kinetics400](https://www.deepmind.com/open-source/kinetics). You can download the subset from [GitHub](https://github.com/towhee-io/data/releases/download/video-data/reverse_video_search.zip). Only the 200 videos under `train` are used here.\n",
    "\n",
    "The data is organized as follows:\n",
    "- **train:** 20 classes, 10 videos per class (200 in total)\n",
    "- **reverse_video_search.csv:** a CSV file containing an ***id***, ***path***, and ***label*** for each video in the `train` directory\n",
    "\n",
    "Let's take a quick look:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "54568b1a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current\n",
      "                                 Dload  Upload   Total   Spent    Left  Speed\n",
      "  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0\n",
      "100  152M  100  152M    0     0  4455k      0  0:00:34  0:00:34 --:--:-- 5342k\n"
     ]
    }
   ],
   "source": [
    "! curl -L https://github.com/towhee-io/examples/releases/download/data/reverse_video_search.zip -O\n",
    "! unzip -q -o reverse_video_search.zip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "490d1379",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   id                                          path                 label\n",
      "0   0  ./train/country_line_dancing/bTbC3w_NIvM.mp4  country_line_dancing\n",
      "1   1  ./train/country_line_dancing/n2dWtEmNn5c.mp4  country_line_dancing\n",
      "2   2  ./train/country_line_dancing/zta-Iv-xK7I.mp4  country_line_dancing\n",
      "country_line_dancing     10\n",
      "pumping_fist             10\n",
      "playing_trombone         10\n",
      "shuffling_cards          10\n",
      "tap_dancing              10\n",
      "clay_pottery_making      10\n",
      "eating_hotdog            10\n",
      "eating_carrots           10\n",
      "juggling_soccer_ball     10\n",
      "juggling_fire            10\n",
      "javelin_throw            10\n",
      "dunking_basketball       10\n",
      "chopping_wood            10\n",
      "trimming_trees           10\n",
      "using_segway             10\n",
      "pushing_cart             10\n",
      "dancing_gangnam_style    10\n",
      "riding_mule              10\n",
      "drop_kicking             10\n",
      "doing_aerobics           10\n",
      "Name: label, dtype: int64\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "df = pd.read_csv('./reverse_video_search.csv')\n",
    "print(df.head(3))\n",
    "print(df.label.value_counts())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2171f2e7",
   "metadata": {},
   "source": [
    "To make it easier to look up labels and measure results in later steps, we define a helper function in advance:\n",
    "- **ground_truth:** returns the ground-truth label for a video, given its path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "dd1b0ef0",
   "metadata": {},
   "outputs": [],
   "source": [
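    "# Index the labels by path once instead of re-indexing the DataFrame on every call.\n",
    "label_by_path = df.set_index('path')['label']\n",
    "\n",
    "def ground_truth(path):\n",
    "    # Return the label as a one-element list, with underscores replaced by spaces.\n",
    "    return [label_by_path.at[path].replace('_', ' ')]"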
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a30aa33c",
   "metadata": {},
   "source": [
    "## Build System\n",
    "\n",
    "Now we are ready to build a video classification system using the sample data. We will use the [X3D_M](https://arxiv.org/abs/2004.04730) model to predict the most likely action labels for input videos. With the right [Towhee operators](https://towhee.io/operators), you don't need to deal with video preprocessing or model details. The [method-chaining style API](https://towhee.readthedocs.io/en/main/index.html) makes it simple to chain operators and apply them to batches of inputs."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6d1da9c9",
   "metadata": {},
   "source": [
    "### Predict labels\n",
    "\n",
    "Let's take some 'tap_dancing' videos as an example to see how to predict labels for videos in 5 lines. By default, the system predicts the top 5 labels, sorted by score from high to low. You can control the number of labels returned by changing `topk`. Note that the first run will take some time because the model has to be downloaded."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ea759389",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using cache found in /home/junjie.jiangjjj/.cache/torch/hub/facebookresearch_pytorchvideo_main\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<table style=\"border-collapse: collapse;\"><tr><th style=\"text-align: center; font-size: 130%; border: none;\">video_path</th> <th style=\"text-align: center; font-size: 130%; border: none;\">predicts</th> <th style=\"text-align: center; font-size: 130%; border: none;\">scores</th></tr> <tr><td style=\"text-align: center; vertical-align: center; border-right: solid 1px #D3D3D3; border-left: solid 1px #D3D3D3; \">./train/tap_dancing/2WdQuLmw-f4.mp4</td> <td style=\"text-align: center; vertical-align: center; border-right: solid 1px #D3D3D3; border-left: solid 1px #D3D3D3; \"><br>tap dancing</br> <br>dancing charleston</br> <br>breakdancing</br> <br>jumpstyle dancing</br> <br>swing dancing</br></td> <td style=\"text-align: left; vertical-align: center; border-right: solid 1px #D3D3D3; border-left: solid 1px #D3D3D3; \">[0.00492,0.00324,0.00255,0.00253,...] len=5</td></tr> <tr><td style=\"text-align: center; vertical-align: center; border-right: solid 1px #D3D3D3; border-left: solid 1px #D3D3D3; \">./train/tap_dancing/Uf1PiOF8Poc.mp4</td> <td style=\"text-align: center; vertical-align: center; border-right: solid 1px #D3D3D3; border-left: solid 1px #D3D3D3; \"><br>tap dancing</br> <br>dancing ballet</br> <br>country line dancing</br> <br>dancing charleston</br> <br>salsa dancing</br></td> <td style=\"text-align: left; vertical-align: center; border-right: solid 1px #D3D3D3; border-left: solid 1px #D3D3D3; \">[0.0045,0.00362,0.00256,0.0025,...] len=5</td></tr> <tr><td style=\"text-align: center; vertical-align: center; border-right: solid 1px #D3D3D3; border-left: solid 1px #D3D3D3; \">./train/tap_dancing/X7k8twydJIU.mp4</td> <td style=\"text-align: center; vertical-align: center; border-right: solid 1px #D3D3D3; border-left: solid 1px #D3D3D3; \"><br>robot dancing</br> <br>tap dancing</br> <br>breakdancing</br> <br>krumping</br> <br>jumpstyle dancing</br></td> <td style=\"text-align: left; vertical-align: center; border-right: solid 1px #D3D3D3; border-left: solid 1px #D3D3D3; \">[0.00542,0.00279,0.00265,0.00255,...] len=5</td></tr> <tr><td style=\"text-align: center; vertical-align: center; border-right: solid 1px #D3D3D3; border-left: solid 1px #D3D3D3; \">./train/tap_dancing/PGPn8WhG3pM.mp4</td> <td style=\"text-align: center; vertical-align: center; border-right: solid 1px #D3D3D3; border-left: solid 1px #D3D3D3; \"><br>tap dancing</br> <br>dancing charleston</br> <br>country line dancing</br> <br>jumpstyle dancing</br> <br>salsa dancing</br></td> <td style=\"text-align: left; vertical-align: center; border-right: solid 1px #D3D3D3; border-left: solid 1px #D3D3D3; \">[0.00578,0.0029,0.0025,0.00249,...] len=5</td></tr> <tr><td style=\"text-align: center; vertical-align: center; border-right: solid 1px #D3D3D3; border-left: solid 1px #D3D3D3; \">./train/tap_dancing/Krh21z_zyV8.mp4</td> <td style=\"text-align: center; vertical-align: center; border-right: solid 1px #D3D3D3; border-left: solid 1px #D3D3D3; \"><br>tap dancing</br> <br>dancing ballet</br> <br>roller skating</br> <br>dancing charleston</br> <br>hopscotch</br></td> <td style=\"text-align: left; vertical-align: center; border-right: solid 1px #D3D3D3; border-left: solid 1px #D3D3D3; \">[0.00673,0.0025,0.00249,0.00249,...] len=5</td></tr></table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from towhee import ops, pipe, DataCollection\n",
    "from glob import glob\n",
    "\n",
    "p = (\n",
    "    pipe.input('video_path_root')\n",
    "    # Expand the glob pattern into one row per video file.\n",
    "    .flat_map('video_path_root', 'video_path', glob)\n",
    "    # Decode each video and uniformly sample 16 frames.\n",
    "    .map('video_path', 'frames', ops.video_decode.ffmpeg(sample_type='uniform_temporal_subsample', args={'num_samples': 16}))\n",
    "    # Predict the top 5 action labels and their scores with X3D_M.\n",
    "    .map('frames', ('predicts', 'scores', 'features'), ops.action_classification.pytorchvideo(model_name='x3d_m', skip_preprocess=True, topk=5))\n",
    "    .output('video_path', 'predicts', 'scores')\n",
    ")\n",
    "\n",
    "DataCollection(p('./train/tap_dancing/*.mp4')).show()"
   ]
  },
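  {
   "cell_type": "markdown",
   "id": "7c1e4a90",
   "metadata": {},
   "source": [
    "To make the effect of `topk` more concrete, here is a minimal pure-Python sketch of ranking labels by score and keeping the best `k`. It only illustrates the idea; the helper `top_k` and the scores are made up for this example and are not part of Towhee or of the operator's implementation.\n",
    "\n",
    "```python\n",
    "def top_k(scores_by_label, k=5):\n",
    "    # Hypothetical helper: sort (label, score) pairs by score, highest first, and keep k.\n",
    "    ranked = sorted(scores_by_label.items(), key=lambda item: item[1], reverse=True)\n",
    "    labels = [label for label, _ in ranked[:k]]\n",
    "    scores = [score for _, score in ranked[:k]]\n",
    "    return labels, scores\n",
    "\n",
    "# Made-up scores for illustration only.\n",
    "example = {'tap dancing': 0.62, 'breakdancing': 0.2, 'krumping': 0.1, 'hopscotch': 0.08}\n",
    "print(top_k(example, k=2))\n",
    "```\n",
    "\n",
    "A smaller `k` returns fewer candidate labels per video, while a larger `k` makes it more likely that the correct label appears somewhere in the list."
   ]
  },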
  {
   "cell_type": "markdown",
   "id": "2b987a6e",
   "metadata": {},
   "source": [
    "#### Pipeline Explanation\n",
    "\n",
    "Here are some details on the operators used in the assembled pipeline:\n",
    "\n",
    "- `ops.video_decode.ffmpeg()`: an embedded Towhee operator that reads a video as frames, using the specified sampling method and number of samples. [learn more](https://towhee.io/video-decode/ffmpeg)\n",
    "\n",
    "- `ops.action_classification.pytorchvideo()`: an embedded Towhee operator that applies the specified model to video frames; it can be used to predict labels and extract features for a video. [learn more](https://towhee.io/action-classification/pytorchvideo)"
   ]
  },
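  {
   "cell_type": "markdown",
   "id": "5b9d2f13",
   "metadata": {},
   "source": [
    "As a rough illustration of what `uniform_temporal_subsample` with `num_samples` means, the sketch below picks evenly spaced frame indices from a clip. This is a simplified pure-Python approximation, not the operator's actual code; the helper name `uniform_sample_indices` is hypothetical, and the real operator may round indices or handle short videos differently.\n",
    "\n",
    "```python\n",
    "def uniform_sample_indices(num_frames, num_samples):\n",
    "    # Hypothetical helper: spread num_samples indices evenly over [0, num_frames - 1].\n",
    "    if num_frames <= 0 or num_samples <= 0:\n",
    "        raise ValueError('num_frames and num_samples must be positive')\n",
    "    if num_samples == 1:\n",
    "        return [0]\n",
    "    # Integer arithmetic keeps the first index at 0 and the last at num_frames - 1.\n",
    "    return [i * (num_frames - 1) // (num_samples - 1) for i in range(num_samples)]\n",
    "\n",
    "# A 300-frame clip reduced to 16 frames, as in the pipeline above.\n",
    "print(uniform_sample_indices(num_frames=300, num_samples=16))\n",
    "```\n",
    "\n",
    "In this sketch, a clip with fewer frames than `num_samples` simply repeats some indices, so the output always has `num_samples` entries."
   ]
  },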
  {
   "cell_type": "markdown",
   "id": "3b04bc12",
   "metadata": {},
   "source": [
    "### Evaluation\n",
    "\n",
    "We have just shown how to classify videos, but how well does the system perform? \n",
    "In this section, we measure performance with the following metric, averaged over all videos:\n",
    "\n",
    "- **mHR (recall@K):**\n",
    "    - Mean hit ratio (mHR) measures the fraction of ground-truth labels that appear among the returned results, averaged over all videos.\n",
    "    - Since we predict the top K labels and each video has exactly one ground-truth label, the mean hit ratio is equivalent to recall@K."
   ]
  },
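  {
   "cell_type": "markdown",
   "id": "e08a6c47",
   "metadata": {},
   "source": [
    "Before running the full evaluation, the metric can be checked by hand on a tiny example. The sketch below is a simplified stand-in for the evaluation code in this notebook; the function name `hit_ratio_at_k` and the three predictions are made up for illustration.\n",
    "\n",
    "```python\n",
    "def hit_ratio_at_k(ground_truths, predictions, k):\n",
    "    # Share of ground-truth labels found in the top-k predictions, averaged over videos.\n",
    "    ratios = []\n",
    "    for truth, predicted in zip(ground_truths, predictions):\n",
    "        hits = len(set(truth) & set(predicted[:k]))\n",
    "        ratios.append(hits / len(truth))\n",
    "    return sum(ratios) / len(ratios)\n",
    "\n",
    "# Made-up results for three videos, one ground-truth label each.\n",
    "truths = [['tap dancing'], ['riding mule'], ['drop kicking']]\n",
    "preds = [\n",
    "    ['tap dancing', 'breakdancing', 'krumping'],\n",
    "    ['riding camel', 'riding mule', 'petting animal'],\n",
    "    ['punting', 'passing ball', 'kicking ball'],\n",
    "]\n",
    "\n",
    "for k in (1, 3):\n",
    "    print(k, hit_ratio_at_k(truths, preds, k))\n",
    "```\n",
    "\n",
    "With one ground-truth label per video, each video contributes either 0 or 1, so the result is simply the share of videos whose correct label appears in the top K: 1 of 3 for K=1 and 2 of 3 for K=3 in this example."
   ]
  },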
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "6038eee0",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using cache found in /home/junjie.jiangjjj/.cache/torch/hub/facebookresearch_pytorchvideo_main\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<table style=\"border-collapse: collapse;\"><tr><th style=\"text-align: center; font-size: 130%; border: none;\">top1_mean_hit_ratio</th> <th style=\"text-align: center; font-size: 130%; border: none;\">top3_mean_hit_ratio</th> <th style=\"text-align: center; font-size: 130%; border: none;\">top5_mean_hit_ratio</th></tr> <tr><td style=\"text-align: center; vertical-align: center; border-right: solid 1px #D3D3D3; border-left: solid 1px #D3D3D3; \">0.7</td> <td style=\"text-align: center; vertical-align: center; border-right: solid 1px #D3D3D3; border-left: solid 1px #D3D3D3; \">0.875</td> <td style=\"text-align: center; vertical-align: center; border-right: solid 1px #D3D3D3; border-left: solid 1px #D3D3D3; \">0.9</td></tr></table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total time: 34.97685408592224\n"
     ]
    }
   ],
   "source": [
    "import csv\n",
    "import time\n",
    "\n",
    "def read_csv(csv_file):\n",
    "    # Yield the video path from each row of the CSV file.\n",
    "    with open(csv_file, 'r', encoding='utf-8-sig') as f:\n",
    "        for row in csv.DictReader(f):\n",
    "            yield row['path']\n",
    "\n",
    "def mean_hit_ratio(actual, *predicteds):\n",
    "    # For each list of top-k predictions, average the per-video share of ground-truth labels that were hit.\n",
    "    rets = []\n",
    "    for predicted in predicteds:\n",
    "        ratios = []\n",
    "        for act, pre in zip(actual, predicted):\n",
    "            hit_num = len(set(act) & set(pre))\n",
    "            ratios.append(hit_num / len(act))\n",
    "        rets.append(sum(ratios) / len(ratios))\n",
    "    return rets\n",
    "\n",
    "p = (\n",
    "    pipe.input('csv_file')\n",
    "    .flat_map('csv_file', 'path', read_csv)\n",
    "    .map('path', 'frames', ops.video_decode.ffmpeg(sample_type='uniform_temporal_subsample', args={'num_samples': 16}))\n",
    "    .map('frames', ('predicts', 'scores', 'features'), ops.action_classification.pytorchvideo(model_name='x3d_m', skip_preprocess=True, topk=5))\n",
    "    .map('predicts', ('top1', 'top3', 'top5'), lambda x: (x[:1], x[:3], x[:5]))\n",
    "    .map('path', 'ground_truth', ground_truth)\n",
    "    .window_all(('ground_truth', 'top1', 'top3', 'top5'), ('top1_mean_hit_ratio', 'top3_mean_hit_ratio', 'top5_mean_hit_ratio'),  mean_hit_ratio)\n",
    "    .output('top1_mean_hit_ratio', 'top3_mean_hit_ratio', 'top5_mean_hit_ratio')\n",
    ")\n",
    "\n",
    "start = time.time()\n",
    "DataCollection(p('reverse_video_search.csv')).show()\n",
    "end = time.time()\n",
    "print(f'Total time: {end-start}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a80f350",
   "metadata": {},
   "source": [
    "## Optimization\n",
    "\n",
    "You're encouraged to experiment with this tutorial. Here we present some optimization options that can improve accuracy, latency, and resource usage. With these methods, you can make the classification system perform better and be more practical in production."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "924813d7",
   "metadata": {},
   "source": [
    "### Change model\n",
    "\n",
    "Many other video models are available, built on different networks. Larger or more complex models usually give better results but cost more to run. You can try different models to find the right trade-off among accuracy, latency, and resource usage. Here we show the performance of video classification using a SOTA model with a [multiscale vision transformer](https://arxiv.org/abs/2104.11227) as the backbone. The average recall increases by about 3 percentage points at the cost of more compute; in the run recorded below, total time rose from about 35.0 s to 36.6 s, and the slowdown may be larger on other hardware."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "b10ba34e",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using cache found in /home/junjie.jiangjjj/.cache/torch/hub/facebookresearch_pytorchvideo_main\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<table style=\"border-collapse: collapse;\"><tr><th style=\"text-align: center; font-size: 130%; border: none;\">top1_mean_hit_ratio</th> <th style=\"text-align: center; font-size: 130%; border: none;\">top3_mean_hit_ratio</th> <th style=\"text-align: center; font-size: 130%; border: none;\">top5_mean_hit_ratio</th></tr> <tr><td style=\"text-align: center; vertical-align: center; border-right: solid 1px #D3D3D3; border-left: solid 1px #D3D3D3; \">0.745</td> <td style=\"text-align: center; vertical-align: center; border-right: solid 1px #D3D3D3; border-left: solid 1px #D3D3D3; \">0.9</td> <td style=\"text-align: center; vertical-align: center; border-right: solid 1px #D3D3D3; border-left: solid 1px #D3D3D3; \">0.92</td></tr></table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total time: 36.61180877685547\n"
     ]
    }
   ],
   "source": [
    "\n",
    "p = (\n",
    "    pipe.input('csv_file')\n",
    "    .flat_map('csv_file', 'path', read_csv)\n",
    "    .map('path', 'frames', ops.video_decode.ffmpeg(sample_type='uniform_temporal_subsample', args={'num_samples': 32}))\n",
    "    .map('frames', ('predicts', 'scores', 'features'), ops.action_classification.pytorchvideo(model_name='mvit_base_32x3', skip_preprocess=True, topk=5))\n",
    "    .map('predicts', ('top1', 'top3', 'top5'), lambda x: (x[:1], x[:3], x[:5]))\n",
    "    .map('path', 'ground_truth', ground_truth)\n",
    "    .window_all(('ground_truth', 'top1', 'top3', 'top5'), ('top1_mean_hit_ratio', 'top3_mean_hit_ratio', 'top5_mean_hit_ratio'),  mean_hit_ratio)\n",
    "    .output('top1_mean_hit_ratio', 'top3_mean_hit_ratio', 'top5_mean_hit_ratio')\n",
    ")\n",
    "\n",
    "import time\n",
    "start = time.time()\n",
    "DataCollection(p('reverse_video_search.csv')).show()\n",
    "end = time.time()\n",
    "print(f'Total time: {end-start}')"
   ]
  },
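  {
   "cell_type": "markdown",
   "id": "3fa8d1b6",
   "metadata": {},
   "source": [
    "To compare the two runs side by side, the short sketch below works out the differences from the numbers shown in the recorded outputs of the two evaluation cells. The dictionary names are hypothetical, the times are rounded to two decimals, and the results will differ on other hardware or with a cold model cache.\n",
    "\n",
    "```python\n",
    "# Metrics copied from the recorded outputs of the two evaluation cells.\n",
    "x3d_m = {'top1': 0.7, 'top3': 0.875, 'top5': 0.9, 'seconds': 34.98}\n",
    "mvit = {'top1': 0.745, 'top3': 0.9, 'top5': 0.92, 'seconds': 36.61}\n",
    "\n",
    "# Absolute gain in mean hit ratio for each K, and the average gain.\n",
    "gains = {k: round(mvit[k] - x3d_m[k], 3) for k in ('top1', 'top3', 'top5')}\n",
    "mean_gain = round(sum(gains.values()) / len(gains), 3)\n",
    "\n",
    "# Relative total time of the MViT run compared with the X3D_M run.\n",
    "time_ratio = round(mvit['seconds'] / x3d_m['seconds'], 3)\n",
    "\n",
    "print(gains, mean_gain, time_ratio)\n",
    "```\n",
    "\n",
    "For these recorded runs, the gains are 4.5, 2.5, and 2 percentage points for top 1, top 3, and top 5, which averages to 3 points, while the total time is roughly 5% higher."
   ]
  },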
  {
   "cell_type": "markdown",
   "id": "607783a1",
   "metadata": {},
   "source": [
    "## Release a Showcase\n",
    "\n",
    "We can build a quick demo of the classification pipeline (`action_classification_pipe` in the code below) with [Gradio](https://gradio.app/)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a78c9ba6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using cache found in /home/junjie.jiangjjj/.cache/torch/hub/facebookresearch_pytorchvideo_main\n"
     ]
    }
   ],
   "source": [
    "import gradio\n",
    "\n",
    "topk = 3\n",
    "\n",
    "def label(predicts, scores):\n",
    "    # Map each of the top-k predicted labels to its score, as a dict for gradio.Label.\n",
    "    return dict(zip(predicts[:topk], scores[:topk]))\n",
    "\n",
    "action_classification_pipe = (\n",
    "    pipe.input('video')\n",
    "    .map('video', 'frames', ops.video_decode.ffmpeg(sample_type='uniform_temporal_subsample', args={'num_samples': 32}))\n",
    "    .map('frames', ('predicts', 'scores', 'features'), ops.action_classification.pytorchvideo(model_name='mvit_base_32x3', skip_preprocess=True, topk=topk))\n",
    "    .map(('predicts', 'scores'), 'label', label)\n",
    "    .output('label')\n",
    ")\n",
    "\n",
    "def action_classification_function(video):\n",
    "    # Take the first output row and its only column (the label dict).\n",
    "    return action_classification_pipe(video).to_list()[0][0]\n",
    "\n",
    "interface = gradio.Interface(action_classification_function,\n",
    "                             inputs=gradio.Video(source='upload'),\n",
    "                             outputs=[gradio.Label(num_top_classes=topk)]\n",
    "                            )\n",
    "\n",
    "\n",
    "interface.launch(inline=True, share=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3ef0b4e",
   "metadata": {},
   "source": [
    "<img src='action_classification_demo.png' alt='action_classification_demo' width=700px/>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e74dfea7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
